# -*- encoding: utf-8 -*-
"""
@File    : fun_expand_dims.py
@Author  : lilong
@Time    : 2022/3/13 11:24 上午
"""


"""
参考：https://blog.csdn.net/u014769320/article/details/99696898
"""

from keras import backend as K
import numpy as np

# Demo input: three identical rows of [2, 3], i.e. a (3, 2) integer array.
# np.array copies each row, so repeating the same inner list is safe here.
x = np.array([[2, 3]] * 3)

print("x shape:", x.shape)

# K.expand_dims(x, axis) inserts a new length-1 dimension at position `axis`.
# For a (3, 2) input the expected results are:
#   axis 0  -> (1, 3, 2)
#   axis 1  -> (3, 1, 2)
#   axis 2  -> (3, 2, 1)
#   axis -1 -> (3, 2, 1)  (negative axes count from the end, so -1 == 2 here)
y1, y2, y3, y4 = (K.expand_dims(x, axis) for axis in (0, 1, 2, -1))

# Show each expanded result alongside its shape, in axis order.
for label, expanded in (("y1", y1), ("y2", y2), ("y3", y3), ("y4", y4)):
    print(f"{label} shape:", expanded.shape, expanded)
